import os
import json
import numpy as np
import networkx as nx
from tqdm import tqdm
from hovsg.graph.object import Object

def load_nav_graph(connectivity_dir, scan):
    ''' Build the undirected navigation graph of one scan.

    Reads ``<connectivity_dir>/<scan>_connectivity.json`` and links every pair
    of entries that are both flagged ``included`` and marked ``unobstructed``
    with respect to each other. Nodes are keyed by ``image_id``, edges carry a
    ``weight`` equal to the Euclidean distance between the two poses, and each
    node that has at least one edge receives a ``position`` attribute
    (a 3-element numpy array).

    Returns:
        networkx.Graph: the connectivity graph of the scan.

    Raises:
        AssertionError: if an ``unobstructed`` link is not symmetric.
    '''

    def distance(pose1, pose2):
        ''' Euclidean distance between two graph poses '''
        # Entries 3, 7 and 11 of 'pose' are used as the x/y/z coordinates
        # (presumably the translation of a row-major 4x4 matrix -- TODO confirm).
        dx = pose1['pose'][3] - pose2['pose'][3]
        dy = pose1['pose'][7] - pose2['pose'][7]
        dz = pose1['pose'][11] - pose2['pose'][11]
        return (dx**2 + dy**2 + dz**2)**0.5

    graph_file = os.path.join(connectivity_dir, '%s_connectivity.json' % scan)
    with open(graph_file) as f:
        viewpoints = json.load(f)

    G = nx.Graph()
    positions = {}
    for i, src in enumerate(viewpoints):
        if not src['included']:
            continue
        for j, visible in enumerate(src['unobstructed']):
            # Only link to viewpoints that are visible and part of the graph.
            if not (visible and viewpoints[j]['included']):
                continue
            dst = viewpoints[j]
            src_pose = src['pose']
            positions[src['image_id']] = np.array(
                [src_pose[3], src_pose[7], src_pose[11]])
            assert dst['unobstructed'][i], 'Graph should be undirected'
            G.add_edge(src['image_id'], dst['image_id'],
                       weight=distance(src, dst))

    nx.set_node_attributes(G, values=positions, name='position')
    return G

class Nav3DSG:
    ''' Builds LLM prompts that ask which candidate matches a target landmark.

    For one scene the class holds the loaded scene-graph objects (``objs``),
    the viewpoint-to-position table (``vp2pos``), a compact connectivity map
    (``connectivity_map``: integer node id -> rounded position) and the
    viewpoint-name -> integer-id table (``node2id``).
    '''

    # Default data locations (relative to the working directory). They can be
    # overridden per instance through the keyword arguments of ``__init__``.
    obj_root = "../HOV-SG/data/scene_graphs/hm3dsem"
    vp2pos_dir = "../HOV-SG/vp2pos"
    connectivity_dir = "../VLN-DUET/datasets/R2R/connectivity"

    # System message sent with every prompt; the text is consumed by the LLM
    # verbatim, so it must not be reformatted.
    _SYSTEM_PROMPT = """[Task Background]
You are an advanced 3D environment understanding assistant. Your main objective is to interpret a language-based instruction describing an indoor environment and identify which candidate landmark best matches the specified target.

[Input Definitions]
1. 'Instruction': A natural language description involves spatial relationships (e.g., relative positions, distances) among landmarks.
2. 'Candidates': A list of candidates for all the landmarks mentioned in the instruction. Each landmark is one of the following:
    - Floor: No explicit candidate data is given. You must infer the positions belonging to this floor from the connectivity map.
    - Room: The candidates are a list of nodes in the connectivity map.
    - Object: Includes the object's unique identifier and the 3D coordinates of its center.
3. 'Target': The specific landmark name within the instruction that you must locate among the given candidates.
4. 'Connectivity Map': A representation of the environment's layout, including a list of key positions.
5. 'Start Position': The agent's initial 3D location, which may be referenced in the instruction.

[Coordinate System]
All 3D positions (x, y, z) follows the convention:
- x-axis: Left to right, increasing to the right.
- y-axis: Floor to ceiling, increasing upward.
- z-axis: Front to back, increasing forward.

[Output Requirements]
Analyze the provided information to decide which candidate is the correct match for the 'Target Landmark'. Consider all clues from the natural language description—particularly any spatial relationships—and compare them with the bounding boxes and 3D positions of the candidates.
Your output must identify a single candidate as the correct match. Format your answer as:
    'The correct candidate is <Candidate_ID>.'
"""

    def __init__(self, scene, *, obj_root=None, vp2pos_dir=None, connectivity_dir=None):
        ''' Load objects, viewpoint positions and connectivity for ``scene``.

        Args:
            scene: scene/scan identifier used to build all file paths.
            obj_root: optional override of the scene-graph root directory.
            vp2pos_dir: optional override of the directory holding
                ``vp2pos_<scene>.json``.
            connectivity_dir: optional override of the connectivity directory.
        '''
        self.scene = scene
        if obj_root is not None:
            self.obj_root = obj_root
        if vp2pos_dir is not None:
            self.vp2pos_dir = vp2pos_dir
        if connectivity_dir is not None:
            self.connectivity_dir = connectivity_dir

        obj_dir = os.path.join(self.obj_root, scene, "graph", "objects")
        self.objs = self.load_objs(obj_dir)

        vp2pos_file = os.path.join(self.vp2pos_dir, f"vp2pos_{scene}.json")
        with open(vp2pos_file, 'r', encoding='utf-8') as f:
            self.vp2pos = json.load(f)
        self.connectivity_map, self.node2id = self.get_connectivity_map(scene)

    def load_objs(self, obj_dir):
        ''' Load one ``Object`` per ``.ply`` file found in ``obj_dir``.

        The file list is sorted because ``os.listdir`` returns entries in
        arbitrary order; sorting makes the result reproducible across runs.

        Returns:
            list: the loaded objects.
        '''
        obj_files = sorted(x for x in os.listdir(obj_dir) if x.endswith(".ply"))
        objs = []
        for obj_file in obj_files:
            # The part of the file name before the first dot is used both as
            # the first and the third argument of Object.
            obj_name = obj_file.split(".")[0]
            obj = Object(obj_name, None, obj_name)
            obj.load(obj_dir)
            objs.append(obj)
        print(f"Loaded {len(objs)} objects")
        return objs

    def get_connectivity_map(self, scan):
        ''' Build the compact connectivity map of ``scan``.

        Returns:
            tuple: ``(connectivity_map, nodes2id)`` where ``nodes2id`` maps
            every graph node to its enumeration index and
            ``connectivity_map`` maps that index to the node position from
            ``self.vp2pos`` rounded to 2 decimals. Nodes without an entry in
            ``self.vp2pos`` keep their id but are left out of the map.
        '''
        G = load_nav_graph(self.connectivity_dir, scan)

        nodes2id = {node: i for i, node in enumerate(G.nodes())}
        connectivity_map = {
            node_id: [round(x, 2) for x in self.vp2pos[node]]
            for node, node_id in nodes2id.items()
            if node in self.vp2pos
        }
        return connectivity_map, nodes2id

    def get_objs_info(self):
        ''' Summarise every loaded object as ``{"id", "position"}``.

        ``position`` is the mean of ``obj.pcd.points`` rounded to 2 decimals
        (``pcd`` is presumably a point cloud set by ``Object.load`` -- TODO
        confirm against hovsg).
        '''
        obj_infos = []
        for obj in self.objs:
            obj_center = np.asarray(obj.pcd.points).mean(axis=0)
            obj_infos.append({
                "id": obj.object_id,
                "position": [round(x, 2) for x in obj_center.tolist()],
            })
        return obj_infos

    def generate_prompt(self, instruction, all_landmarks, lm_info, objs_info, start_pos):
        ''' Compose the system and user prompt for one instruction.

        Args:
            instruction: natural-language instruction text.
            all_landmarks: newline-separated text; every line but the last is
                ``"<n>. <name> (<type>)"`` and the last is
                ``"<label>: <target name> (<type>)"``.
            lm_info: mapping whose key ``str(i)`` holds the candidates of the
                i-th landmark (viewpoint names for rooms, object ids for
                objects). Floors need no entry.
            objs_info: list of ``{"id", "position"}`` dicts, as returned by
                ``get_objs_info``.
            start_pos: start position inserted verbatim into the prompt.

        Returns:
            tuple: ``(system_prompt, prompt)``.

        Raises:
            KeyError: if a room/object landmark has no candidate entry, a room
                candidate is unknown to the connectivity map, or an object
                candidate is missing from ``objs_info``.
            IndexError: if a landmark line lacks the ``.``/``(`` separators or
                the target line lacks ``:``.
        '''
        *landmark_lines, target_line = all_landmarks.split('\n')
        # Split only on the FIRST separator so that names which themselves
        # contain '.' or ':' (e.g. "St. James sofa") are not truncated.
        landmarks = [line.split('.', 1)[1].strip() for line in landmark_lines]
        target = target_line.split(':', 1)[1].strip().split('(')[0].strip()

        landmark_names = [x.split('(')[0].strip() for x in landmarks]
        landmark_types = [x.split('(')[1].split(')')[0].strip() for x in landmarks]

        # One-time id -> info table; the first entry wins on duplicate ids.
        obj_by_id = {}
        for info in objs_info:
            obj_by_id.setdefault(info['id'], info)

        parts = [f"1. 'Instruction': {instruction}\n2. 'Candidates':\n"]
        for idx, (lm, lm_type) in enumerate(zip(landmark_names, landmark_types)):
            parts.append(f"Landmark: {lm} ({lm_type})\n")
            kind = lm_type.lower()
            if kind == 'floor':
                # Floors carry no candidates, so lm_info is not consulted.
                parts.append("This landmark is a floor which need to be inferred from the connectivity map. Note that floor index starts from 1.\n")
            elif kind == 'room':
                for i, cand in enumerate(lm_info[str(idx)], start=1):
                    node_id = self.node2id[cand]
                    parts.append(
                        f"Candidate {i}: <id: {node_id}, position: {self.connectivity_map[node_id]}>\n"
                    )
            elif kind == 'object':
                for i, cand in enumerate(lm_info[str(idx)], start=1):
                    try:
                        cand_info = obj_by_id[cand]
                    except KeyError:
                        raise KeyError(
                            f"object candidate {cand!r} not found in objs_info"
                        ) from None
                    parts.append(f"Candidate {i}: {cand_info}\n")
                # NOTE: the blank separator line is emitted after object
                # landmarks only; kept as-is so generated prompts stay stable.
                parts.append('\n')

        parts.append(
            f"3. 'Target': {target}\n"
            f"4. 'Connectivity Map': {self.connectivity_map}\n"
            f"5. 'Start Position': {start_pos}\n"
        )
        return self._SYSTEM_PROMPT, "".join(parts)

    @staticmethod
    def _annotation_id(inst_id, reverie_data):
        ''' Map an instruction id to the key of its entry in ``reverie_data``.

        The historical rule (drop the last two characters, i.e. a
        single-character suffix plus its separator) is tried first; when that
        key is absent, everything after the last underscore is dropped so
        that multi-digit suffixes such as ``"_10"`` also resolve.
        '''
        legacy_id = inst_id[:-2]
        if legacy_id in reverie_data:
            return legacy_id
        return inst_id.rsplit('_', 1)[0]

    def save_llm_prompt(self, detection_results, reverie_data):
        ''' Build (but do not write) the LLM input records.

        Args:
            detection_results: mapping of instruction id -> landmark info dict
                holding ``'instruction'``, ``'landmarks'`` and the candidate
                lists read by ``generate_prompt``.
            reverie_data: mapping of annotation id -> annotation; the first
                viewpoint of ``'path'`` is used as the start position.

        Returns:
            list: one ``{"sample_id", "messages", "system"}`` dict per
            instruction, in the iteration order of ``detection_results``.
        '''
        objs_info = self.get_objs_info()
        llm_input = []
        for inst_id, lm_info in detection_results.items():
            annotation = reverie_data[self._annotation_id(inst_id, reverie_data)]
            start_pos = [round(x, 2) for x in self.vp2pos[annotation['path'][0]]]
            system_prompt, user_content = self.generate_prompt(
                lm_info['instruction'],
                lm_info['landmarks'],
                lm_info,
                objs_info,
                start_pos,
            )
            llm_input.append({
                "sample_id": inst_id,
                "messages": [{"role": "user", "content": user_content}],
                "system": system_prompt,
            })

        return llm_input

# ---- Driver: build the LLM inputs of one scene for every qualified file ----
target_scene = "zsNo4HB9uLZ"
split = "val_unseen"

agent = Nav3DSG(scene=target_scene)

# REVERIE annotations of the chosen split, indexed by their 'id' field.
annotation_file = f"../VLN-DUET/datasets/REVERIE/annotations/REVERIE_{split}_enc.json"
with open(annotation_file, 'r') as f:
    reverie_data = {entry['id']: entry for entry in json.load(f)}

# Names of the detection-result files to process.
with open("node_generation/qualified_args.json", 'r') as f:
    qualified_args = json.load(f)

save_path = os.path.join("node_generation/llm_input", split, target_scene)
os.makedirs(save_path, exist_ok=True)

detection_dir = f"node_generation/detection_results/{split}/{target_scene}"
for args in tqdm(qualified_args):
    with open(f"{detection_dir}/{args}", 'r') as f:
        detection_results = json.load(f)

    # Each output file is written under the same name as its detection file.
    llm_input = agent.save_llm_prompt(detection_results, reverie_data)
    with open(os.path.join(save_path, args), 'w') as f:
        json.dump(llm_input, f, indent=4)